import torch
from tqdm import tqdm
import torch.nn as nn

def accuracy(model, loader, device, show=True, transform=None, return_losses=False):
    """Compute the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left in eval mode) and every
    batch is evaluated under ``torch.no_grad()``.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores;
            predictions are taken as ``outputs.argmax(1)``.
        loader: Iterable yielding ``(images, target)`` pairs of tensors.
        device: Device the inputs, targets and outputs are moved to.
        show: If true, wrap ``loader`` in a ``tqdm`` progress bar whose
            description shows the running accuracy.
        transform: Optional callable applied to each input batch after it
            has been moved to ``device``. Skipped when falsy.
        return_losses: If true, also collect the per-sample cross-entropy
            loss of every evaluated sample.

    Returns:
        The accuracy as a percentage (0-100). When ``return_losses`` is
        true, a tuple ``(accuracy, losses)`` where ``losses`` is a flat list
        with one loss value per sample, in iteration order. An empty
        ``loader`` yields an accuracy of ``0.0`` and an empty loss list.
    """
    correct, total = 0, 0
    all_losses = []
    # Defined before the loop so an empty loader returns 0.0 instead of
    # raising UnboundLocalError at the return statement.
    acc = 0.0

    # reduction="none" keeps one loss value per sample instead of a batch mean.
    criterion = nn.CrossEntropyLoss(reduction="none")

    model.eval()
    with torch.no_grad():
        t = tqdm(loader) if show else loader
        for images, target in t:
            images = images.to(device)
            if transform:
                images = transform(images)
            target = target.to(device)
            outputs = model(images).to(device)
            correct += (outputs.argmax(1) == target).sum().item()
            total += target.numel()
            # Guard against batches with no targets (total still zero).
            acc = correct / total if total else 0.0

            if return_losses:
                batch_losses = criterion(outputs, target).cpu().detach().numpy()
                all_losses.extend(batch_losses)

            if show:
                t.set_description(f'test acc: {acc*100:.2f}%')

    if return_losses:
        return acc * 100, all_losses
    return acc * 100